
#include <new>
using namespace std;

#include "DataPacket.h"

// Lvalue for element (i, j) of a row-major buffer whose rows are 'width'
// elements long. Every parameter is parenthesised so that compound arguments
// such as ACCESS(p, height+offset_y-1, x, w) index row (height+offset_y-1)
// instead of being split apart by operator precedence. Because it expands to
// an lvalue, both &ACCESS(...) and assignment through it work.
#define ACCESS(data, i, j, width) ((data)[(i)*(width)+(j)])

// Records the interior size (width x height) and the ghost-cell padding
// (ghost_cells[0] columns and ghost_cells[1] rows on each side). No buffer is
// allocated here; init() does that through common_init(). 'data' is cleared
// so that destroying a packet on which init() was never called performs
// delete[] on NULL (a no-op) rather than on an indeterminate pointer.
DataPacket::DataPacket(int width, int height, INT_PAIR ghost_cells) :
  width(width), height(height), ghost_cells(ghost_cells) {
  data = NULL;
}

// Releases the padded cell buffer that common_init() obtained with new[].
// NOTE(review): the MPI datatypes committed in common_init() (column_type and
// row_type) are not released here. delete[] is only safe if 'data' holds that
// allocation or NULL; confirm that every construction path initialises it.
DataPacket::~DataPacket() {
  delete[] data;
}

bool DataPacket::init(const CONTAINER_TYPE *data_blk) {
  // Allocate data and initialize pointers
  if (!common_init())
    return false;

  // Copy the data block into the packet
  for (int i=0; i<height; i++) {
    memcpy(data+(i+ghost_cells[1])*new_w+ghost_cells[0], data_blk+i*width, \
	width*sizeof(CONTAINER_TYPE));
  }
  
  // Initialize ghost cells
  ghost_init();

  return true;
}

bool DataPacket::init(const CONTAINER_TYPE val) {
  // Allocate data and initialize pointers
  if (!common_init())
    return false;

  // A temporary data container initialized with 'val'
  vector<CONTAINER_TYPE> tmp_data(width, val);

  // Copy the value into the packet
  for (int i=0; i<height; i++) {
    memcpy(data+(i+ghost_cells[1])*new_w+ghost_cells[0], &tmp_data[0], \
	width*sizeof(CONTAINER_TYPE));
  }
  
  // Initialize ghost cells
  ghost_init();

  return true;
}

// Returns a writable reference to cell (i, j) of the padded buffer. Both
// indices are padded coordinates: the interior starts at row ghost_cells[1],
// column ghost_cells[0]. No bounds checking is performed.
CONTAINER_TYPE& DataPacket::access_data(int i, int j) {
  CONTAINER_TYPE *row_start = data + i*new_w;
  return row_start[j];
}

void DataPacket::ghost_init() {
  // Populate ghost cells (row)
  for (int i=ghost_cells[1]; i<new_h-ghost_cells[1]; i++) {
    for (int j=0; j<ghost_cells[0]; j++) {
      // TODO: Replace with the ACCESS macro
      data[i*new_w+j] = data[i*new_w+ghost_cells[0]];
      data[i*new_w+width+ghost_cells[0]+j]= data[i*new_w+width+ghost_cells[0]-1];
    }
  }

  // Populate ghost cells (col)
  for (int i=0; i<ghost_cells[1]; i++)
    memcpy(&data[i*new_w+0], &data[ghost_cells[1]*new_w+0],
	new_w*sizeof(CONTAINER_TYPE));

  for (int i=ghost_cells[1]+width; i<new_h; i++)
    memcpy(&data[i*new_w+0], &data[(height+ghost_cells[1]-1)*new_w+0],
	new_w*sizeof(CONTAINER_TYPE));
}

// Allocates the padded buffer (new_w x new_h, row-major, row stride new_w).
// It then points the interior-edge (_x) and ghost (gx) location pointers into
// that buffer and builds the MPI datatypes used by the interchange_* routines.
// Returns false if allocation or any MPI datatype call fails.
//
// Interior cells occupy rows ghost_cells[1] .. ghost_cells[1]+height-1 and
// columns ghost_cells[0] .. ghost_cells[0]+width-1. The _x pointers address
// the outermost interior cells. The gx pointers address the ghost layer
// immediately adjacent to the interior, which is the layer the one-cell-deep
// exchanges in interchange_* fill.
//
// NOTE(review): assumes ghost_cells[0] and ghost_cells[1] are >= 1.
// NOTE(review): calling this twice (e.g. init() twice) leaks the previous
// buffer and datatypes.
// NOTE(review): the datatypes are built from MPI_FLOAT, which assumes
// CONTAINER_TYPE is float. Confirm.
bool DataPacket::common_init() {
  _l = _t = _r = _b = NULL;
  gl = gt = gr = gb = NULL;
  _ll = _ul = _lr = _ur = NULL;
  gll = gul = glr = gur = NULL;

  new_w = width+2*ghost_cells[0];
  new_h = height+2*ghost_cells[1];

  // Allocate data buffer for the DataPacket
  try {
    data = new CONTAINER_TYPE[new_w*new_h];
  }
  catch (bad_alloc &ba) {
    fprintf(stderr, "bad_alloc caught : %s\n", ba.what());

    return false;
  }

  // Row/column indices of the interior edges and of the adjacent ghost layer.
  // Every ACCESS argument below is a single identifier, so the index
  // arithmetic is correct whether or not the macro parenthesises its
  // parameters.
  const int first_row = ghost_cells[1];
  const int last_row = ghost_cells[1]+height-1;
  const int first_col = ghost_cells[0];
  const int last_col = ghost_cells[0]+width-1;
  const int ghost_top = first_row-1;
  const int ghost_bottom = last_row+1;
  const int ghost_left = first_col-1;
  const int ghost_right = last_col+1;

  // Initialize the location pointers: interior edges
  _l = &ACCESS(data, first_row, first_col, new_w);
  _t = _l;
  _r = &ACCESS(data, first_row, last_col, new_w);
  _b = &ACCESS(data, last_row, first_col, new_w);

  // Ghost edges
  gl = &ACCESS(data, first_row, ghost_left, new_w);
  gt = &ACCESS(data, ghost_top, first_col, new_w);
  gr = &ACCESS(data, first_row, ghost_right, new_w);
  gb = &ACCESS(data, ghost_bottom, first_col, new_w);

  // Interior corners
  _ll = &ACCESS(data, last_row, first_col, new_w);
  _ul = &ACCESS(data, first_row, first_col, new_w);
  _lr = &ACCESS(data, last_row, last_col, new_w);
  _ur = &ACCESS(data, first_row, last_col, new_w);

  // Ghost corners
  gll = &ACCESS(data, ghost_bottom, ghost_left, new_w);
  gul = &ACCESS(data, ghost_top, ghost_left, new_w);
  glr = &ACCESS(data, ghost_bottom, ghost_right, new_w);
  gur = &ACCESS(data, ghost_top, ghost_right, new_w);

  // Initialize MPI data types. A column is 'height' single cells spaced one
  // padded row (new_w) apart; a row is 'width' contiguous cells.
  if (MPI_Type_vector(height, 1, width+2*ghost_cells[0], MPI_FLOAT,
        &column_type)!=MPI_SUCCESS) {
    printf("MPI_Type_vector failed!!!\n");

    return false;
  }

  if (MPI_Type_commit(&column_type)!=MPI_SUCCESS) {
    printf("MPI_Type_commit failed!!!\n");

    return false;
  }

  if (MPI_Type_contiguous(width, MPI_FLOAT, &row_type)!=MPI_SUCCESS) {
    printf("MPI_Type_contiguous failed!!!\n");

    return false;
  }

  if (MPI_Type_commit(&row_type)!=MPI_SUCCESS) {
    printf("MPI_Type_commit failed!!!\n");

    return false;
  }

  return true;
}

// Exchanges one column of boundary cells with the horizontal neighbours.
// proc_map[1][0] is the left neighbour and proc_map[1][2] the right one; -1
// means there is none. Our outermost interior column on each side (_l / _r)
// is sent to the neighbour on that side, and the neighbour's column is
// received into the matching ghost column (gl / gr). On a side without a
// neighbour, the ghost column is refreshed by copying our own outermost
// interior column into it.
// Returns false if a receive does not deliver exactly one column.
bool DataPacket::interchange_lr(MPI_Comm comm, int rank) {
  const int left = proc_map[1][0];
  const int right = proc_map[1][2];
  int test_data_cnt = 0;
  MPI_Status status;

  if (left!=-1 && right!=-1) { // Surrounded by two sides
    // Send to right, receive from left
    MPI_Sendrecv(_r, 1, column_type, right, 123,
        gl, 1, column_type, left, 123, comm, &status);
    MPI_Get_count(&status, column_type, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, left);
      return false;
    }

    // Send to left, receive from right
    MPI_Sendrecv(_l, 1, column_type, left, 123,
        gr, 1, column_type, right, 123, comm, &status);
    MPI_Get_count(&status, column_type, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, right);
      return false;
    }
  }
  else if (right!=-1) { // Only a right neighbour
    // Send to right, receive from right
    MPI_Sendrecv(_r, 1, column_type, right, 123,
        gr, 1, column_type, right, 123, comm, &status);
    MPI_Get_count(&status, column_type, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, right);
      return false;
    }
  }
  else if (left!=-1) { // Only a left neighbour
    // Send to left, receive from left
    MPI_Sendrecv(_l, 1, column_type, left, 123,
        gl, 1, column_type, left, 123, comm, &status);
    MPI_Get_count(&status, column_type, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, left);
      return false;
    }
  }

  // Sides without a neighbour: replicate the outermost interior column into
  // the ghost column. Consecutive cells of a column are new_w elements apart
  // (the full padded row width, the same stride column_type uses).
  for (int i=0; i<height; i++) {
    if (left==-1)
      gl[i*new_w] = _l[i*new_w];
    if (right==-1)
      gr[i*new_w] = _r[i*new_w];
  }

  return true;
}

// Exchanges one row of boundary cells with the vertical neighbours.
// proc_map[0][1] is the neighbour above and proc_map[2][1] the one below; -1
// means there is none. Our outermost interior row on each side (_t / _b) is
// sent to the neighbour on that side, and the neighbour's row is received
// into the matching ghost row (gt / gb). On a side without a neighbour, the
// ghost row is refreshed by copying our own outermost interior row into it.
// Returns false if a receive does not deliver exactly one row.
bool DataPacket::interchange_tb(MPI_Comm comm, int rank) {
  const int top = proc_map[0][1];
  const int bottom = proc_map[2][1];
  int test_data_cnt = 0;
  MPI_Status status;

  if (top!=-1 && bottom!=-1) { // Surrounded by two sides
    // Send to top, receive from bottom
    MPI_Sendrecv(_t, 1, row_type, top, 123,
        gb, 1, row_type, bottom, 123, comm, &status);
    MPI_Get_count(&status, row_type, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, bottom);
      return false;
    }

    // Send to bottom, receive from top
    MPI_Sendrecv(_b, 1, row_type, bottom, 123,
        gt, 1, row_type, top, 123, comm, &status);
    MPI_Get_count(&status, row_type, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, top);
      return false;
    }
  }
  else if (bottom!=-1) { // Only a neighbour below
    // Send to bottom, receive from bottom
    MPI_Sendrecv(_b, 1, row_type, bottom, 123,
        gb, 1, row_type, bottom, 123, comm, &status);
    MPI_Get_count(&status, row_type, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, bottom);
      return false;
    }
  }
  else if (top!=-1) { // Only a neighbour above
    // Send to top, receive from top
    MPI_Sendrecv(_t, 1, row_type, top, 123,
        gt, 1, row_type, top, 123, comm, &status);
    MPI_Get_count(&status, row_type, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, top);
      return false;
    }
  }

  // Sides without a neighbour: replicate the outermost interior row into the
  // ghost row, so boundaries are refreshed even when neither neighbour
  // exists.
  if (top==-1)
    memcpy(gt, _t, width*sizeof(CONTAINER_TYPE));
  if (bottom==-1)
    memcpy(gb, _b, width*sizeof(CONTAINER_TYPE));

  return true;
}

// Exchanges the corner cells along diagonal 0. proc_map[0][0] is the
// upper-left neighbour and proc_map[2][2] the lower-right one; -1 means none.
// _ul is sent to the upper-left neighbour, and that neighbour's cell is
// received into gul. _lr is sent to the lower-right neighbour, and that
// neighbour's cell is received into glr. A ghost corner without a neighbour
// receives a copy of the adjacent interior corner on every call.
// Returns false if a receive does not deliver exactly one cell.
// NOTE(review): single cells are sent as MPI_FLOAT, which assumes
// CONTAINER_TYPE is float. Confirm.
bool DataPacket::interchange_d0(MPI_Comm comm, int rank) {
  const int upper_left = proc_map[0][0];
  const int lower_right = proc_map[2][2];
  int test_data_cnt = 0;
  MPI_Status status;

  if (upper_left!=-1 && lower_right!=-1) { // Surrounded by two sides
    // Send to gul, receive from glr
    MPI_Sendrecv(_ul, 1, MPI_FLOAT, upper_left, 123,
        glr, 1, MPI_FLOAT, lower_right, 123, comm, &status);
    MPI_Get_count(&status, MPI_FLOAT, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, lower_right);
      return false;
    }

    // Send to glr, receive from gul
    MPI_Sendrecv(_lr, 1, MPI_FLOAT, lower_right, 123,
        gul, 1, MPI_FLOAT, upper_left, 123, comm, &status);
    MPI_Get_count(&status, MPI_FLOAT, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, upper_left);
      return false;
    }
  }
  else if (lower_right!=-1) { // Only a lower-right neighbour
    // Send to glr, receive from glr
    MPI_Sendrecv(_lr, 1, MPI_FLOAT, lower_right, 123,
        glr, 1, MPI_FLOAT, lower_right, 123, comm, &status);
    MPI_Get_count(&status, MPI_FLOAT, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, lower_right);
      return false;
    }
  }
  else if (upper_left!=-1) { // Only an upper-left neighbour
    // Send to gul, receive from gul
    MPI_Sendrecv(_ul, 1, MPI_FLOAT, upper_left, 123,
        gul, 1, MPI_FLOAT, upper_left, 123, comm, &status);
    MPI_Get_count(&status, MPI_FLOAT, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, upper_left);
      return false;
    }
  }

  // Corners without a neighbour: copy the value, not the pointer. This way,
  // [row-1][col-1] / [row+1][col+1] reads through the data buffer (e.g.
  // access_data) see the current interior corner, and gul/glr keep
  // addressing their ghost cells.
  if (upper_left==-1)
    *gul = *_ul;
  if (lower_right==-1)
    *glr = *_lr;

  return true;
}

// Exchanges the corner cells along diagonal 1 (the anti-diagonal).
// proc_map[0][2] is the upper-right neighbour and proc_map[2][0] the
// lower-left one; -1 means none. _ur is sent to the upper-right neighbour, and
// that neighbour's cell is received into gur. _ll is sent to the lower-left
// neighbour, and that neighbour's cell is received into gll. A ghost corner
// without a neighbour receives a copy of the adjacent interior corner on
// every call.
// Returns false if a receive does not deliver exactly one cell.
// NOTE(review): single cells are sent as MPI_FLOAT, which assumes
// CONTAINER_TYPE is float. Confirm.
bool DataPacket::interchange_d1(MPI_Comm comm, int rank) {
  const int upper_right = proc_map[0][2];
  const int lower_left = proc_map[2][0];
  int test_data_cnt = 0;
  MPI_Status status;

  if (upper_right!=-1 && lower_left!=-1) { // Surrounded by two sides
    // Send to gur, receive from gll
    MPI_Sendrecv(_ur, 1, MPI_FLOAT, upper_right, 123,
        gll, 1, MPI_FLOAT, lower_left, 123, comm, &status);
    MPI_Get_count(&status, MPI_FLOAT, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, lower_left);
      return false;
    }

    // Send to gll, receive from gur
    MPI_Sendrecv(_ll, 1, MPI_FLOAT, lower_left, 123,
        gur, 1, MPI_FLOAT, upper_right, 123, comm, &status);
    MPI_Get_count(&status, MPI_FLOAT, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, upper_right);
      return false;
    }
  }
  else if (lower_left!=-1) { // Only a lower-left neighbour
    // Send to gll, receive from gll
    MPI_Sendrecv(_ll, 1, MPI_FLOAT, lower_left, 123,
        gll, 1, MPI_FLOAT, lower_left, 123, comm, &status);
    MPI_Get_count(&status, MPI_FLOAT, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, lower_left);
      return false;
    }
  }
  else if (upper_right!=-1) { // Only an upper-right neighbour
    // Send to gur, receive from gur
    MPI_Sendrecv(_ur, 1, MPI_FLOAT, upper_right, 123,
        gur, 1, MPI_FLOAT, upper_right, 123, comm, &status);
    MPI_Get_count(&status, MPI_FLOAT, &test_data_cnt);
    if (test_data_cnt!=1) {
      printf("Rank %d: error receiving data from proc %d\n", rank, upper_right);
      return false;
    }
  }

  // Corners without a neighbour: replicate the adjacent interior corner so
  // that [row-1][col+1] / [row+1][col-1] reads through the data buffer see
  // current values.
  if (upper_right==-1)
    *gur = *_ur;
  if (lower_left==-1)
    *gll = *_ll;

  return true;
}

// Exchanges both diagonal corner pairs and returns true only if both
// succeed. Both exchanges always run, even if the first one fails, and always
// in the order d0 then d1. Separate statements are used because the operand
// evaluation order of 'a & b' is unspecified in C++, which would leave the
// order of the underlying MPI_Sendrecv calls up to the compiler.
bool DataPacket::interchange_co(MPI_Comm comm, int rank) {
  const bool d0_ok = interchange_d0(comm, rank);
  const bool d1_ok = interchange_d1(comm, rank);
  return d0_ok && d1_ok;
}

// Performs the full ghost-cell exchange: left/right columns, then top/bottom
// rows, then the diagonal corners. Returns true only if every stage succeeds.
// All stages run even if an earlier one fails. They are sequenced with
// separate statements because the operand order of 'a & b & c' is
// unspecified in C++, so the MPI call order would otherwise be up to the
// compiler.
bool DataPacket::interchange(MPI_Comm comm, int rank) {
  const bool lr_ok = interchange_lr(comm, rank);
  const bool tb_ok = interchange_tb(comm, rank);
  const bool co_ok = interchange_co(comm, rank);
  return lr_ok && tb_ok && co_ok;
}
